import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math

rng = np.random.default_rng(20260525)

def make_semantic_graph(M, avg_degree=4):
    """Build a random undirected graph as a list of neighbour arrays.

    The graph starts as a ring (node ``i`` linked to ``(i + 1) % M``) and is
    then topped up with random edges whose endpoints are drawn from the
    module-level ``rng``. Pairs with identical endpoints are discarded and
    repeated pairs collapse, so the realised mean degree can end up below
    ``avg_degree``.

    Args:
        M: Number of nodes.
        avg_degree: Target mean degree; the ring alone accounts for a mean
            degree of 2, so values <= 2 add no random edges.

    Returns:
        A list of ``M`` integer arrays; entry ``i`` holds the neighbours of
        node ``i`` (in no particular order). No node lists itself.
    """
    adj = [set() for _ in range(M)]

    for i in range(M):
        j = (i + 1) % M
        # With a single node the ring closes on itself; skip that edge so
        # the result never contains a self-loop, consistent with the
        # ``x != y`` filter applied to the random edges below.
        if i != j:
            adj[i].add(j)
            adj[j].add(i)

    # The ring already supplies M edges; draw the remainder needed to reach
    # roughly M * avg_degree / 2 edges in total.
    extra_edges = max(0, int(M * avg_degree / 2) - M)

    if extra_edges > 0:
        sources = rng.integers(0, M, size=extra_edges)
        targets = rng.integers(0, M, size=extra_edges)

        # tolist() yields plain Python ints, so the sets never mix int types.
        for x, y in zip(sources.tolist(), targets.tolist()):
            if x != y:
                adj[x].add(y)
                adj[y].add(x)

    return [np.array(list(s), dtype=int) for s in adj]

def make_weights(M, heterogeneity):
    """Return ``M`` positive weights normalised to sum to one.

    Args:
        M: Number of weights to produce.
        heterogeneity: Sigma of the log-normal distribution (mean 0) the raw
            weights are sampled from via the module-level ``rng``. Values
            <= 0 give equal raw weights instead, and no random numbers are
            consumed.

    Returns:
        A float array of length ``M`` holding the raw weights divided by
        their total.
    """
    raw = (
        np.ones(M)
        if heterogeneity <= 0
        else rng.lognormal(mean=0.0, sigma=heterogeneity, size=M)
    )
    total = raw.sum()
    return raw / total

def graph_inference(known_direct, adj, weights, rho_c, semantic_pressure, max_rounds):
    """Spread knowledge across the graph in synchronous probabilistic rounds.

    In each round, every node that is still unknown and has at least one
    known neighbour is inferred with probability

        min(1, 1 - (1 - rho_c) ** s),
        s = semantic_pressure * (k + M * W)

    where ``k`` is the number of known neighbours, ``W`` the sum of their
    weights and ``M`` the number of nodes. All decisions in a round use the
    knowledge state from the start of that round. One draw is taken from the
    module-level ``rng`` per candidate node, in ascending node order.
    Iteration stops after ``max_rounds`` rounds or as soon as a round infers
    nothing.

    Args:
        known_direct: Boolean array marking nodes known before inference.
            It is copied, not modified.
        adj: Sequence of integer neighbour arrays, one per node.
        weights: Array of per-node weights.
        rho_c: Base of the per-unit-signal inference probability.
        semantic_pressure: Multiplier applied to the local signal.
        max_rounds: Maximum number of propagation rounds.

    Returns:
        ``(inferred_total, known)``: boolean arrays marking the nodes gained
        through inference and all nodes known at the end, respectively.
    """
    total = len(known_direct)
    known = known_direct.copy()
    inferred_total = np.zeros(total, dtype=bool)

    for _ in range(max_rounds):
        # Nodes inferred this round; applied only after the sweep so that
        # every decision sees the same start-of-round state.
        gained = []

        for node in range(total):
            if known[node] or adj[node].size == 0:
                continue

            neighbours = adj[node]
            informed = neighbours[known[neighbours]]
            if informed.size == 0:
                continue

            exposure = informed.size + total * weights[informed].sum()
            miss_chance = (1.0 - rho_c) ** (semantic_pressure * exposure)

            if rng.random() < min(1.0, 1.0 - miss_chance):
                gained.append(node)

        if not gained:
            break

        inferred_total[gained] = True
        known[gained] = True

    return inferred_total, known

def simulate_scenario(
    label,
    I_G,
    I0,
    lambda_factor,
    rho_c,
    avg_degree,
    heterogeneity,
    semantic_pressure,
    max_rounds,
    q_values,
    trials
):
    """Run the Monte Carlo capture/inference experiment for one scenario.

    The secret is split into ``M = ceil(I_G / I0)`` fragments, each present
    ``lambda_factor`` times in a pool of ``N`` entries. A graph and a weight
    vector over the fragments are generated once. For every ``q`` in
    ``q_values`` and every trial:

    1. ``round(q * N)`` pool entries are captured without replacement; a
       fragment is directly known if any of its entries was captured.
    2. ``graph_inference`` extends the known set.
    3. The residual entropy is the total weight of the still-missing
       fragments times ``I_G``; the guess probability is ``2 ** -H_res``
       (taken as 0 when ``H_res > 1024``).
    4. The attacker wins if every fragment is known or a Bernoulli draw with
       the guess probability succeeds.

    Args:
        label: Scenario name copied into every output row.
        I_G: Total information content (reported as ``I_G_bits``).
        I0: Information per fragment (reported as ``I0_bits``).
        lambda_factor: Number of pool entries per fragment.
        rho_c, semantic_pressure, max_rounds: Passed to ``graph_inference``.
        avg_degree: Passed to ``make_semantic_graph``.
        heterogeneity: Passed to ``make_weights``.
        q_values: Iterable of capture fractions to evaluate.
        trials: Number of Monte Carlo repetitions per capture fraction.

    Returns:
        A ``pandas.DataFrame`` with one row per value of ``q`` containing the
        scenario parameters and the per-``q`` summary statistics.
    """
    M = math.ceil(I_G / I0)

    # Graph first, then weights: both consume the shared random stream.
    adj = make_semantic_graph(M, avg_degree)
    weights = make_weights(M, heterogeneity)

    replica_pool = np.repeat(np.arange(M), lambda_factor)
    N = len(replica_pool)

    # Columns that are identical for every row of this scenario.
    scenario_columns = {
        "I_G_bits": I_G,
        "I0_bits": I0,
        "lambda_factor": lambda_factor,
        "rho_c": rho_c,
        "avg_degree": avg_degree,
        "heterogeneity": heterogeneity,
        "semantic_pressure": semantic_pressure,
        "max_inference_rounds": max_rounds,
        "M_fragments": M,
        "N_total": N,
    }

    metric_names = (
        "win",
        "by_knowledge",
        "by_entropy",
        "K_direct",
        "K_inf",
        "K_adv",
        "H_res",
        "missing_count",
        "inferred_count",
        "P_guess",
    )

    rows = []

    for q in q_values:
        capture_size = int(round(q * N))
        samples = {name: [] for name in metric_names}

        for _ in range(trials):
            captured = rng.choice(replica_pool, size=capture_size, replace=False)

            known_direct = np.zeros(M, dtype=bool)
            known_direct[np.unique(captured)] = True

            inferred, known_total = graph_inference(
                known_direct,
                adj,
                weights,
                rho_c,
                semantic_pressure,
                max_rounds
            )
            missing = ~known_total

            H_res = weights[missing].sum() * I_G
            P_guess = 0.0 if H_res > 1024 else 2.0 ** (-H_res)

            by_knowledge = known_total.all()
            # Drawn on every trial, even when by_knowledge already holds.
            by_entropy = rng.random() < P_guess

            outcome = {
                "win": by_knowledge or by_entropy,
                "by_knowledge": by_knowledge,
                "by_entropy": by_entropy,
                "K_direct": weights[known_direct].sum(),
                "K_inf": weights[inferred].sum(),
                "K_adv": weights[known_total].sum(),
                "H_res": H_res,
                "missing_count": missing.sum(),
                "inferred_count": inferred.sum(),
                "P_guess": P_guess,
            }
            for name, value in outcome.items():
                samples[name].append(value)

        rows.append({
            "label": label,
            "q": q,
            **scenario_columns,
            "P_win": float(np.mean(samples["win"])),
            "P_complete_by_knowledge": float(np.mean(samples["by_knowledge"])),
            "P_complete_by_entropy": float(np.mean(samples["by_entropy"])),
            "mean_K_direct": float(np.mean(samples["K_direct"])),
            "mean_K_inferential": float(np.mean(samples["K_inf"])),
            "mean_K_adv": float(np.mean(samples["K_adv"])),
            "p95_K_adv": float(np.quantile(samples["K_adv"], 0.95)),
            "mean_H_res": float(np.mean(samples["H_res"])),
            "p05_H_res": float(np.quantile(samples["H_res"], 0.05)),
            "mean_missing_count": float(np.mean(samples["missing_count"])),
            "mean_inferred_count": float(np.mean(samples["inferred_count"])),
            "mean_P_guess": float(np.mean(samples["P_guess"])),
        })

    return pd.DataFrame(rows)

# ============================================================
# EXECUTION
# ============================================================

# Capture fractions swept on the x-axis and Monte Carlo repetitions per point.
q_values = np.linspace(0.01, 0.99, 50)
trials = 350

# Each tuple is unpacked positionally into simulate_scenario, so the fields
# follow its parameter order:
# (label, I_G, I0, lambda_factor, rho_c, avg_degree, heterogeneity,
#  semantic_pressure, max_rounds)
scenarios = [
    ("Uniform low graph inference", 1000, 20, 2, 0.001, 3, 0.0, 0.5, 2),
    ("Weighted moderate graph inference", 1000, 20, 2, 0.003, 4, 0.8, 1.0, 3),
    ("High pulverization low inference", 1000, 5, 2, 0.001, 3, 0.8, 0.5, 2),
    ("High pulverization stress graph", 1000, 5, 2, 0.003, 6, 1.2, 1.5, 4),
]

# One summary frame per scenario, stacked into a single table.
scenario_frames = [
    simulate_scenario(*fields, q_values=q_values, trials=trials)
    for fields in scenarios
]
results = pd.concat(scenario_frames, ignore_index=True)

# Persist the combined table in both formats (to_excel needs an Excel
# writer engine to be installed alongside pandas).
results.to_csv("cnvs_test8_graph_weighted_entropy_results.csv", index=False)
results.to_excel("cnvs_test8_graph_weighted_entropy_results.xlsx", index=False)

# ============================================================
# PLOTS 1-3 — ONE FIGURE PER SUMMARY METRIC, ONE CURVE PER SCENARIO
# ============================================================

# (results column, y-axis label, figure title, legend location)
plot_specs = [
    # Plot 1 — complete reconstruction probability
    (
        "P_win",
        "Probability of complete unauthorized reconstruction",
        "CNVS Test 8 — Graph-Weighted Entropic Monte Carlo",
        "upper left",
    ),
    # Plot 2 — weighted K_adv
    (
        "mean_K_adv",
        "Mean weighted adversarial knowledge K_adv",
        "CNVS Test 8 — Weighted K_adv",
        "upper left",
    ),
    # Plot 3 — weighted residual entropy
    (
        "mean_H_res",
        "Mean weighted residual entropy H_res (bits)",
        "CNVS Test 8 — Weighted Residual Entropy",
        "upper right",
    ),
]

for metric_column, y_label, figure_title, legend_loc in plot_specs:
    plt.figure(figsize=(12, 7))

    # groupby yields the scenarios in sorted label order.
    for scenario_label, scenario_rows in results.groupby("label"):
        plt.plot(
            scenario_rows["q"],
            scenario_rows[metric_column],
            linewidth=2.4,
            label=scenario_label,
        )

    # Vertical marker at q = 1/3, drawn purely as a visual reference.
    plt.axvline(
        x=1/3,
        color="black",
        linestyle="--",
        alpha=0.7,
        label="BFT reference line (visual only)",
    )

    plt.xlabel("Fraction of network physically compromised by attacker (q)")
    plt.ylabel(y_label)
    plt.title(figure_title)
    plt.grid(True, linestyle=":", alpha=0.7)
    plt.legend(fontsize=8, loc=legend_loc)
    plt.tight_layout()
    plt.show()

# Console summary: a preview of the results table, then the output file names.
print(results.head())
print("\nSaved:")
for saved_name in (
    "cnvs_test8_graph_weighted_entropy_results.csv",
    "cnvs_test8_graph_weighted_entropy_results.xlsx",
):
    print(saved_name)
